/*
 * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.baomidou.mybatisplus.core.injector.methods;

import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.core.enums.SqlMethod;
import com.baomidou.mybatisplus.core.injector.AbstractMethod;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.core.toolkit.sql.SqlScriptUtils;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Batch update method injector.
 * <p>
 * Builds a dialect-specific batch UPDATE statement for the {@link DbType} of each table.
 * Generators are registered per database type by {@link #init()} (MySQL, Oracle, Oracle 12c,
 * PostgreSQL). For a database type without a registered generator a warning is logged once and
 * the {@link SqlMethod#UPDATE_BATCH} template is formatted with an empty string.
 *
 * @author wanglei
 * @since 2022-03-23
 */
public class UpdateBatch extends AbstractMethod {

    /**
     * Batch-update SQL generators keyed by database type; populated by {@link #init()}.
     */
    private static final Map<DbType, GenUpdateSql> DB_FUNCTIONS = new HashMap<>();

    /**
     * Database types for which no batch-update SQL can be generated and whose warning has
     * already been logged, so the warning is emitted only once per type.
     */
    public static final Set<DbType> WARN_SET = new HashSet<>();

    public UpdateBatch() {
        super(SqlMethod.UPDATE_BATCH.getMethod());
    }

    /**
     * @param name method name
     * @since 3.5.0
     */
    public UpdateBatch(String name) {
        super(name);
    }

    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        SqlMethod sqlMethod = SqlMethod.UPDATE_BATCH;
        DbType dbType = tableInfo.getDbType();
        GenUpdateSql generator = DB_FUNCTIONS.get(dbType);
        String sql = "";
        if (generator != null) {
            sql = generator.genSql(mapperClass, modelClass, tableInfo);
        } else if (!WARN_SET.contains(dbType)) {
            // Unsupported database type: warn the user, but only once per type.
            logger.warn("db type: " + dbType.getDb() + " does not support batch update");
            WARN_SET.add(dbType);
        }
        sql = String.format(sqlMethod.getSql(), sql);
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, sql, modelClass);
        return this.addUpdateMappedStatement(mapperClass, modelClass, getMethod(sqlMethod), sqlSource);
    }

    /**
     * Generates the MySQL batch update SQL.
     * <p>
     * Uses the UPDATE ... JOIN form (noted by the original author as the best performing),
     * for example:
     * <pre>
     * update user a join(
     *   select 1 as user_id, 'Tom' as name, 30 as age, 'M' as sex, 2 as school_id
     * ) b USING(user_id)
     * set a.name = b.name,
     *     a.age = b.age,
     *     a.sex = b.sex,
     *     a.school_id = b.school_id
     * </pre>
     * One SELECT per list element is generated and the SELECTs are combined with UNION ALL.
     *
     * @param mapperClass mapper class
     * @param modelClass  model class
     * @param tableInfo   table metadata
     * @return the generated SQL script
     */
    public String genMysqlUpdateBatchSql(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        StringBuilder sql = new StringBuilder(UPDATE);
        sql.append(tableInfo.getTableName()).append("  a ").append(JOIN).append(LEFT_BRACKET);
        // Column list such as "col1,col2,col3,"; String.split drops the trailing empty entry.
        String selectColumns = tableInfo.getAllInsertSqlColumn(null);
        String[] columns = selectColumns.split(COMMA);
        // Placeholders such as "#{item.col1},#{item.col2},"; splitting on "}," leaves "#{item.colN".
        String sqlProperty = tableInfo.getAllInsertSqlProperty(Constants.ITEM + DOT);
        String[] sqlPropertys = sqlProperty.split(RIGHT_BRACE + COMMA);
        // Builds "#{item.col1} AS col1,#{item.col2} AS col2,..."
        StringBuilder selectValueField = new StringBuilder();
        if (tableInfo.getIdType() == IdType.AUTO) {
            // Presumably the auto-increment key is not among the insert columns, so add it explicitly.
            selectValueField.append(HASH + LEFT_BRACE + Constants.ITEM + DOT
                + tableInfo.getKeyProperty() + RIGHT_BRACE).append(AS).append(tableInfo.getKeyColumn()).append(COMMA);
        }

        // Columns and placeholders are paired by position; a count mismatch means the metadata is
        // inconsistent, so report it instead of silently swallowing an index error.
        if (sqlPropertys.length != columns.length) {
            logger.warn("batch update: column count (" + columns.length + ") does not match property count ("
                + sqlPropertys.length + ") for table " + tableInfo.getTableName());
        }
        int pairs = Math.min(sqlPropertys.length, columns.length);
        for (int i = 0; i < pairs; i++) {
            if (StringUtils.isNotBlank(sqlPropertys[i])) {
                // Restore the "}" removed by split() before aliasing the value to its column.
                selectValueField.append(sqlPropertys[i]).append(RIGHT_BRACE).append(AS).append(columns[i]).append(COMMA);
            }
        }

        String selectFrom = SqlScriptUtils.convertForeach(SELECT + SqlScriptUtils.convertTrim(selectValueField.toString(),
            null, null, null, COMMA), Constants.COLLECTION, Constants.INDEX, Constants.ITEM, UNION_ALL);
        sql.append(selectFrom);
        sql.append(RIGHT_BRACKET).append("  b ").append(USING).append(LEFT_BRACKET).append(tableInfo.getKeyColumn()).append(RIGHT_BRACKET);
        sql.append(SET);
        String tableA = "a.";
        String tableB = "b.";
        StringBuilder set = new StringBuilder();
        for (String column : columns) {
            set.append(tableA).append(column).append(EQUALS).append(tableB).append(column).append(COMMA);
        }
        sql.append(SqlScriptUtils.convertTrim(set.toString(),
            null, null, null, COMMA));
        return sql.toString();
    }

    /**
     * Generates the PostgreSQL batch update SQL in the form
     * {@code UPDATE t SET col = vtable.col, ... FROM (VALUES (...), (...)) AS vtable(cols) WHERE vtable.key = t.key}.
     * <p>
     * Each value is wrapped in a {@code <choose>} so that a null property renders as the literal {@code null}.
     *
     * @param mapperClass mapper class
     * @param modelClass  model class
     * @param tableInfo   table metadata
     * @return the generated SQL script
     */
    public String genPostgresqlUpdateBatchSql(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        // UPDATE table
        StringBuilder sql = new StringBuilder(UPDATE).append(tableInfo.getTableName());
        // Column list such as "col1,col2,col3,"
        String selectColumns = tableInfo.getAllInsertSqlColumn(null);
        String[] columns = selectColumns.split(COMMA);
        StringBuilder set = new StringBuilder();
        String vtable = " vtable";
        for (String column : columns) {
            set.append(column).append(EQUALS).append(vtable).append(DOT).append(column.trim()).append(COMMA);
        }
        // SET col1 = vtable.col1, col2 = vtable.col2, col3 = vtable.col3
        sql.append(SqlScriptUtils.convertSet(SqlScriptUtils.convertTrim(set.toString(), null, null, null, COMMA)));
        sql.append(FROM).append(LEFT_BRACKET).append(VALUES);

        // Placeholders such as "#{item.col1},#{item.col2},"; splitting on "}," leaves "#{item.colN".
        String sqlProperty = tableInfo.getAllInsertSqlProperty(Constants.ITEM + DOT);
        String[] sqlPropertys = sqlProperty.split(RIGHT_BRACE + COMMA);
        StringBuilder selectValueField = new StringBuilder();
        if (tableInfo.getIdType() == IdType.AUTO) {
            // The key column is prepended to the vtable column list below, so its value comes first
            // and must be comma-separated from the following values.
            selectValueField.append(HASH + LEFT_BRACE + Constants.ITEM + DOT + tableInfo.getKeyProperty() + RIGHT_BRACE)
                .append(COMMA);
        }
        // Use <choose>/<when> so a null property is rendered as the SQL literal null.
        for (String property : sqlPropertys) {
            if (StringUtils.isNotBlank(property)) {
                // Restore the "}" removed by split() so the placeholder is well-formed.
                String placeholder = property + RIGHT_BRACE;
                selectValueField.append(SqlScriptUtils.convertChoose(parsePropertyPostgresqlNotNull(property), placeholder, "null"))
                    .append(COMMA);
            }
        }
        sql.append(SqlScriptUtils.convertForeach(SqlScriptUtils.convertTrim(selectValueField.toString(),
            StringPool.LEFT_BRACKET, StringPool.RIGHT_BRACKET, null, COMMA), Constants.COLLECTION, Constants.INDEX, Constants.ITEM, COMMA));
        sql.append(RIGHT_BRACKET).append(AS).append(vtable).append(LEFT_BRACKET);
        if (tableInfo.getIdType() == IdType.AUTO) {
            selectColumns = tableInfo.getKeyColumn() + COMMA + selectColumns;
        }
        sql.append(SqlScriptUtils.convertTrim(selectColumns, null, null, null, COMMA));
        sql.append(RIGHT_BRACKET).append(WHERE)
            .append(vtable).append(DOT).append(tableInfo.getKeyColumn()).append(EQUALS)
            .append(tableInfo.getTableName()).append(DOT).append(tableInfo.getKeyColumn());
        return sql.toString();
    }

    /**
     * Converts a placeholder fragment such as {@code #{item.col}} or {@code #{item.col,jdbcType=VARCHAR}}
     * into the not-null test expression {@code item.col!=null}.
     *
     * @param property placeholder fragment (with or without its closing brace)
     * @return the OGNL not-null test for the property
     */
    private String parsePropertyPostgresqlNotNull(String property) {
        String expression = property.replace("#{", "").replace("}", "");
        // Drop placeholder options (jdbcType, typeHandler, ...) which are not part of the property path.
        int optionStart = expression.indexOf(COMMA);
        if (optionStart >= 0) {
            expression = expression.substring(0, optionStart);
        }
        return expression.trim() + "!=null";
    }

    /**
     * Generates the Oracle batch update SQL: one UPDATE per list element, joined by ";" inside a
     * {@code begin ... ;end;} block.
     *
     * @param mapperClass mapper class
     * @param modelClass  model class
     * @param tableInfo   table metadata
     * @return the generated SQL script
     */
    public String genOracleUpdateBatchSql(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        // UPDATE <table> <set> WHERE <key>=#{item.<keyProperty>} <version/logic-delete conditions>
        String template = "UPDATE %s %s WHERE %s=#{%s} %s";
        final String additional = optlockVersion(tableInfo) + tableInfo.getLogicDeleteSql(true, true);
        String oneRowUpdateSql = String.format(template, tableInfo.getTableName(),
            sqlSet(tableInfo.isWithLogicDelete(), false, tableInfo, false, Constants.ITEM, Constants.ITEM + DOT),
            tableInfo.getKeyColumn(), Constants.ITEM + DOT + tableInfo.getKeyProperty(), additional);
        return SqlScriptUtils.convertForeach(oneRowUpdateSql, Constants.COLLECTION, Constants.INDEX, Constants.ITEM, SEMICOLON, "begin", ";end;");
    }

    /**
     * Registers the per-database batch update SQL generators.
     *
     * @return this instance
     */
    public UpdateBatch init() {
        // MySQL: tested
        DB_FUNCTIONS.put(DbType.MYSQL, this::genMysqlUpdateBatchSql);
        // Oracle: tested
        DB_FUNCTIONS.put(DbType.ORACLE, this::genOracleUpdateBatchSql);
        DB_FUNCTIONS.put(DbType.ORACLE_12C, this::genOracleUpdateBatchSql);

        // PostgreSQL
        DB_FUNCTIONS.put(DbType.POSTGRE_SQL, this::genPostgresqlUpdateBatchSql);
        return this;
    }

    /**
     * Strategy for generating database-specific batch update SQL.
     *
     * @author wanglei
     * @since 2022-03-22
     */
    @FunctionalInterface
    public interface GenUpdateSql {
        /**
         * Generates the SQL.
         *
         * @param mapperClass mapper class
         * @param modelClass  model class
         * @param tableInfo   table metadata
         * @return sql
         */
        String genSql(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo);
    }

}
